{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7660e1bc",
   "metadata": {},
   "source": [
    "### Exploring prototypical distortions\n",
    "\n",
    "This notebook contains code to demonstrate that generative networks such as VAEs make their outputs more prototypical.\n",
    "\n",
    "Tested with TensorFlow 2.11.0 and Python 3.10.9."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cdea851",
   "metadata": {},
   "source": [
    "#### Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35353ac0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel\n",
    "%pip install -r requirements.txt\n",
    "%pip install umap-learn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ec94c58",
   "metadata": {},
   "source": [
    "#### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd26f10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from distortions_utils import *\n",
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "from random import randrange\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.backends.backend_pdf\n",
    "from config import dims_dict\n",
    "from generative_model import models_dict\n",
    "import matplotlib\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "from umap import UMAP\n",
    "from tensorflow.keras import layers\n",
    "import tensorflow.keras.backend as K\n",
    "from utils import display\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# set tensorflow random seed to make outputs reproducible\n",
    "tf.random.set_seed(123)\n",
    "np.random.seed(123)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bf6fb79",
   "metadata": {},
   "source": [
    "#### Measuring intra-class variation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9badb650",
   "metadata": {},
   "source": [
    "Load the previously trained MNIST VAE and the MNIST test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd2f65e",
   "metadata": {},
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "\n",
    "# Work on copies of the test split only; the training split is not used in this notebook\n",
    "mnist_labels = np.concatenate([y_test], axis=0)\n",
    "mnist_digits = np.concatenate([x_test], axis=0)\n",
    "# Append a trailing channel axis, then cast to float32 and divide by 255\n",
    "mnist_digits = np.expand_dims(mnist_digits, -1).astype(\"float32\") / 255\n",
    "\n",
    "# Build the encoder/decoder pair (latent_dim=20) and load the saved weights into each network\n",
    "encoder, decoder = build_encoder_decoder_small(latent_dim=20)\n",
    "for network, weights_file in ((encoder, \"model_weights/mnist_encoder.h5\"),\n",
    "                              (decoder, \"model_weights/mnist_decoder.h5\")):\n",
    "    network.load_weights(weights_file)\n",
    "vae = VAE(encoder, decoder, kl_weighting=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f3ca454",
   "metadata": {},
   "source": [
    "Get latents and outputs before and after recall, and optionally plot the latent spaces.\n",
    "\n",
    "Note that no noise is applied in the `check_generative_recall()` function from `distortions_utils`, as this would invalidate the comparison of the variances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8758c4cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data = mnist_digits\n",
    "all_pixels = []\n",
    "all_latents = []\n",
    "\n",
    "# Both UMAP reducers are created with the same settings\n",
    "umap_settings = dict(n_components=2, min_dist=1, n_neighbors=20)\n",
    "\n",
    "# Fit one reducer on the encoder output and one on the flattened pixels (latent first, then pixel)\n",
    "latents = vae.encoder.predict(test_data)\n",
    "latent_umap = UMAP(**umap_settings)\n",
    "# First element of the encoder output -- presumably the latent means; TODO confirm\n",
    "latent_umap.fit(latents[0])\n",
    "pixel_umap = UMAP(**umap_settings)\n",
    "pixel_umap.fit(test_data.reshape(len(test_data), 784))\n",
    "\n",
    "# all_pixels[0] holds the inputs; all_pixels[1] holds the images returned by one recall pass\n",
    "# NOTE(review): n=3000 is forwarded to check_generative_recall -- confirm the returned images\n",
    "# still line up one-to-one with mnist_labels, which later cells index them with\n",
    "for recall_step in range(2):\n",
    "    all_pixels.append(test_data)\n",
    "    test_data, latents = check_generative_recall(\n",
    "        vae, test_data, mnist_labels, latent_umap, pixel_umap, displaybool=False, n=3000\n",
    "    )\n",
    "    all_latents.append(latents)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27b6d9e9",
   "metadata": {},
   "source": [
    "Calculate the per-pixel variance within each MNIST class (using up to 500 test images per class) before and after recall:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc99c01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps time step (0 = inputs, 1 = after recall) to a list with one per-pixel variance vector per digit\n",
    "variances = {}\n",
    "labels_subset = mnist_labels[0:10000]\n",
    "\n",
    "for time_step in range(2):\n",
    "    px = all_pixels[time_step][0:10000]\n",
    "    per_class = []\n",
    "    for digit in range(10):\n",
    "        # Keep at most the first 500 images of this digit\n",
    "        digit_idx = np.where(labels_subset == digit)[0]\n",
    "        px_for_digit = px[digit_idx][0:500]\n",
    "\n",
    "        # Flatten every image into a row vector, then take each pixel's variance across the images\n",
    "        reshaped_images = px_for_digit.reshape((px_for_digit.shape[0], -1))\n",
    "        per_class.append(np.var(reshaped_images, axis=0))\n",
    "    variances[time_step] = per_class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e3f95a7-e2ab-467a-93c2-14387e77373d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: shape of the flattened image matrix left over from the last iteration of the loop above\n",
    "reshaped_images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f94b4b6",
   "metadata": {},
   "source": [
    "Plot intra-class variation before and after recall for each MNIST class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "734f1526-6804-49c4-92bb-15a33300f0c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard error of the per-pixel variances for each class (inputs only)\n",
    "[np.std(pixel_var) / np.sqrt(len(pixel_var)) for pixel_var in variances[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d184104d",
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams['font.size'] = 12\n",
    "\n",
    "\n",
    "def standard_error(values):\n",
    "    \"\"\"Standard deviation of `values` divided by the square root of their count.\"\"\"\n",
    "    return np.std(values) / np.sqrt(len(values))\n",
    "\n",
    "\n",
    "# Summarise the per-pixel variance vectors of each class: index 0 = inputs, index 1 = outputs\n",
    "labels = range(10)\n",
    "before_means = [np.mean(v) for v in variances[0]]\n",
    "before_sem = [standard_error(v) for v in variances[0]]\n",
    "after_means = [np.mean(v) for v in variances[1]]\n",
    "after_sem = [standard_error(v) for v in variances[1]]\n",
    "\n",
    "x = np.arange(len(labels))\n",
    "width = 0.4\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7,4))\n",
    "\n",
    "# Inputs are drawn to the left of each tick and outputs to the right\n",
    "bar_style = dict(capsize=5, alpha=0.5)\n",
    "rects1 = ax.bar(x - width/2, before_means, width, yerr=before_sem, label='Inputs', color='red', **bar_style)\n",
    "rects2 = ax.bar(x + width/2, after_means, width, yerr=after_sem, label='Outputs', color='blue', **bar_style)\n",
    "\n",
    "ax.set_title('Intra-class image variation')\n",
    "ax.set_ylabel('Mean variance per pixel')\n",
    "plt.xticks(x)\n",
    "plt.ylim(0, 0.073)\n",
    "ax.legend()\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.savefig('misc_plots/mnist_variance.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a7fdfff-6e3b-4eec-89e7-3b0366370890",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-class vectors of per-pixel variances: inputs (before recall) and outputs (after recall)\n",
    "before_data = variances[0]\n",
    "after_data = variances[1]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 4))\n",
    "\n",
    "# One slot per digit class, spaced two units apart so paired boxes do not touch their neighbours\n",
    "positions = np.arange(len(before_data)) * 2\n",
    "width = 0.65  # Width of the boxes\n",
    "\n",
    "# Whiskers mark the 10th and 90th percentiles; points beyond them are not drawn\n",
    "box1 = ax.boxplot(before_data, positions=positions - width/2, widths=width, patch_artist=True, \n",
    "                  boxprops=dict(facecolor='red', alpha=0.5), medianprops=dict(color='black'),\n",
    "                 showfliers=False, whis=[10,90])\n",
    "box2 = ax.boxplot(after_data, positions=positions + width/2, widths=width, patch_artist=True, \n",
    "                  boxprops=dict(facecolor='blue', alpha=0.5), medianprops=dict(color='black'),\n",
    "                 showfliers=False, whis=[10,90])\n",
    "\n",
    "#ax.set_ylabel('Variance per pixel', fontsize=14)\n",
    "ax.tick_params(axis='y', labelsize=16)\n",
    "plt.ylim(0, 0.235)\n",
    "#ax.set_title('Intra-class image variation')\n",
    "ax.set_xticks(positions)\n",
    "# The variance lists are ordered by digit starting at 0, so label the pairs 0-9 (not 1-10)\n",
    "ax.set_xticklabels(range(len(before_data)), fontsize=16)\n",
    "legend = ax.legend([box1[\"boxes\"][0], box2[\"boxes\"][0]], ['Inputs', 'Outputs'], fontsize=16)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.savefig('misc_plots/mnist_variance_boxplot.pdf')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "690e2fde-cc4a-45d5-b6ed-62eee68b785d",
   "metadata": {},
   "source": [
    "#### Statistical analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c45ec7cf-c948-405f-a4fc-1351380e5cb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import ttest_rel\n",
    "\n",
    "# Pool the per-pixel variances of all classes into one paired sample per condition\n",
    "flattened_before_data = np.concatenate(before_data)\n",
    "flattened_after_data = np.concatenate(after_data)\n",
    "\n",
    "# Paired t-test; ttest_rel(a, b) tests the mean of (a - b), i.e. before minus after\n",
    "res = ttest_rel(flattened_before_data, flattened_after_data)\n",
    "print(res)\n",
    "print(res.confidence_interval())\n",
    "\n",
    "# Cohen's d for paired samples: mean paired difference over the SD of the differences.\n",
    "# ddof=1 gives the sample SD, matching the estimator used by the paired t-test.\n",
    "# Sign convention here is (after - before), the opposite of the t-test above.\n",
    "differences = flattened_after_data - flattened_before_data\n",
    "mean_difference = np.mean(differences)\n",
    "std_dev_difference = np.std(differences, ddof=1)\n",
    "cohens_d = mean_difference / std_dev_difference\n",
    "print(f\"Cohen's d: {cohens_d}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
